import sys
import pandas as pd
import pickle
from tqdm import tqdm
import random
import numpy as np
import time


from Program_2 import *

from Program_1 import *

# Key view over `fields_of_interest`, which is not defined in this file
# (presumably it is provided by one of the star imports above).
foi_keys = fields_of_interest.keys()
# Demographic sub-groups to run the experiment on.
#
# Outer key: column name in the survey DataFrame.
# Inner key: the "value set" selecting a sub-group, which is one of
#   * a single code          -> rows equal to that code,
#   * a tuple of codes       -> rows whose code is any member of the tuple,
#   * a range(start, stop)   -> rows whose value lies in that numeric interval
# (see filter_dataframe for how each kind is applied).
# Inner value: the sub-group label, used in log output and in the name of the
# output CSV file.
# The insertion order of the outer dict matters: the main loop selects
# columns by their position in this dict.
conditions = {
    # Labels: male / female.
    'gender': {
        1: "男性",
        2: "女性"
    },
    # Labels: Han / ethnic minority.
    'nation': {
        1: "汉族",
        (2, 3, 4, 5, 6, 7, 8): "少数民族"
    },
    # Labels: no religious belief / the five major religions / other religions.
    'Religious belief': {
        1: "不信仰宗教",
        (11, 12, 14, 15, 16): "五大宗教",
        (13, 17, 18, 19, 20, 21): "其他宗教",
    },
    # Labels: no education at all / below bachelor's degree / bachelor's degree
    # and above.
    'degree': {
         (1, 14): "没有受过任何教育",
         (2, 3, 4, 5, 6, 7, 8, 9, 10): "本科以下",
         (11, 12, 13): "本科及以上"
    },
    # Labels: low / lower-middle / upper-middle / high income.
    # NOTE(review): the top bucket's range holds ~10**8 integers, so it must
    # never be expanded element by element.
    'Total revenue in 2020': {
          range(0, 30000): "低收入",
          range(30000, 60000): "中低收入",
          range(60000, 150000): "中高收入",
          range(150000, 100000000): "高收入"
    },
    # Labels: general public, don't know, refused to answer / Communist Youth
    # League member / democratic parties / Communist Party member.
    'Political status': {
          1: "群众, 不知道, 拒绝回答",
          2: "共青团员",
          3: "民主党派",
          4: "共产党员"
    }
}

def compute_demographic_distribution(df):
    """Return the relative frequency of every value of each field of interest.

    Args:
        df: DataFrame containing one column per key of `fields_of_interest`.

    Returns:
        A dict mapping each field name to a ``{value: proportion}`` dict, as
        produced by ``Series.value_counts(normalize=True)``.
    """
    return {
        field: df[field].value_counts(normalize=True).to_dict()
        for field in fields_of_interest
    }

def generate_fake_respondent(distributions):
    """Draw one synthetic respondent from per-field value distributions.

    Args:
        distributions: dict mapping a field name to a ``{value: probability}``
            dict; each field is sampled independently of the others.

    Returns:
        A dict mapping every field name to one randomly chosen value.
    """
    return {
        field: np.random.choice(list(freqs.keys()), p=list(freqs.values()))
        for field, freqs in distributions.items()
    }

def gen_backstory_from_fake_person(fake_person):
    """Build a first-person backstory string for a synthetic respondent.

    For every field of the respondent, the field's template (taken from
    `fields_of_interest[field]['template']`) has its ``XXX`` placeholder
    replaced, and the resulting sentences are joined with single spaces.

    Args:
        fake_person: dict mapping a field name to the respondent's value.

    Returns:
        The backstory text, or an empty string when no field contributes a
        sentence (empty respondent, or every value missing or unmapped).
    """
    sentences = []
    for field, anes_val in fake_person.items():
        # Skip missing answers: the literal '#NULL!' marker, None, or NaN.
        if anes_val in ['#NULL!', None] or pd.isna(anes_val):
            continue
        field_spec = fields_of_interest[field]
        template = field_spec['template']
        valmap = field_spec['valmap']
        if not valmap:
            # No value map: insert the raw value itself.
            sentences.append(template.replace('XXX', str(anes_val)))
        elif anes_val in valmap:
            sentences.append(template.replace('XXX', valmap[anes_val]))
        # Values absent from a non-empty value map contribute nothing.
    # Joining avoids indexing into an empty string, which used to raise
    # IndexError when no sentence was produced.
    return " ".join(sentences)

def generate_query_with_backstory(backstory, question):
    """Return the backstory and the question joined by a period and a space."""
    return "{}. {}".format(backstory, question)

def filter_dataframe(df, value_set, column):
    """Return the rows of *df* whose *column* value matches *value_set*.

    Args:
        df: DataFrame to filter.
        value_set: what to match. A tuple matches any of its members; a
            ``range`` matches every value in the half-open interval
            ``[start, stop)`` (the step is ignored); anything else is matched
            by equality.
        column: name of the column to test.

    Returns:
        A DataFrame containing only the matching rows.
    """
    values = df[column]
    if isinstance(value_set, tuple):
        mask = values.isin(value_set)
    elif isinstance(value_set, range):
        # Half-open comparison instead of between(start, stop - 1): the latter
        # drops non-integer values lying in (stop - 1, stop), so e.g. 29999.5
        # would belong to neither range(0, 30000) nor range(30000, 60000).
        mask = (values >= value_set.start) & (values < value_set.stop)
    else:
        mask = values == value_set
    return df[mask]

# Load the survey responses, decoded as GBK. ANES_FN and SEP are not defined
# in this file (presumably they are provided by the star imports above).
anesdf = pd.read_csv(ANES_FN, sep=SEP, encoding='gbk', low_memory=False)
# Load the selected questions; the main loop reads the "Code", "Question" and
# "Answers" columns of this table.
# NOTE(review): absolute, machine-specific path — the script only runs where
# this exact file exists.
anes_2020_questionnaire = pd.read_csv("C:/Users/16150/Desktop/东亚/2021年/ANES_2020_multiple_questions_selected.csv")

def generate_prompt_for_question(question, answers):
    """Format a question and its answer options into the user prompt.

    The prompt ends with an open "my answer is" line for the model to complete.
    """
    template = "问题: {}\n\n回答选项:\n{}\n\n我的回答是\n"
    return template.format(question, answers)



# ---------------------------------------------------------------------------
# Experiment driver.
#
# For every selected question and every demographic sub-group in `conditions`,
# generate as many synthetic respondents as the sub-group has real rows, query
# the model once per respondent, and write one CSV per (question, sub-group).
# ---------------------------------------------------------------------------

# Inclusive range of questionnaire row positions processed by this run.
# The original note gives 0-9 as the valid positions.
START_INDEX = 4
END_INDEX = 4
time_date = "今天是2021年6月1日."

# Inclusive range of positions (in insertion order) into `conditions` selecting
# which demographic columns to run. `conditions` currently has six entries
# (positions 0-5); an end index beyond the last position simply selects
# everything up to the end.
CONDITION_START_INDEX = 0
CONDITION_END_INDEX = 7

# Maximum number of query attempts per synthetic respondent.
MAX_RETRIES = 5
# Base delay in seconds for the exponential back-off applied after a
# connection or rate-limit error (1, 2, 4, ... seconds).
RETRY_BACKOFF_SECONDS = 1.0


def _describe_value_set(value_set):
    """Return a short description of a `conditions` key for log output."""
    if isinstance(value_set, range):
        # Never expand a range element by element: the top income bucket
        # holds ~10**8 integers and would yield a string of hundreds of MB.
        return f"[{value_set.start}, {value_set.stop})"
    if isinstance(value_set, tuple):
        return ' '.join(map(str, value_set))
    return str(value_set)


def _query_with_retries(system_prompt, user_prompt, fake_id):
    """Query the model, retrying on OpenAI errors.

    Returns:
        ``(True, response)`` on success, or ``(False, None)`` once
        MAX_RETRIES attempts have all failed.
    """
    for attempt in range(1, MAX_RETRIES + 1):
        back_off = False
        try:
            return True, do_query(system_prompt, user_prompt)
        except openai.APIConnectionError as e:
            print("The server could not be reached")
            print(e.__cause__)  # an underlying Exception, likely raised within httpx.
            back_off = True
        except openai.RateLimitError:
            print("A 429 status code was received; we should back off a bit.")
            back_off = True
        except openai.APIStatusError as e:
            print("Another non-200-range status code was received")
            print(e.status_code)
            print(e.response)
        except openai.OpenAIError as e:
            print(f"API Error: {e}")
        # Wait before the next attempt, but not after the final one.
        if back_off and attempt < MAX_RETRIES:
            time.sleep(RETRY_BACKOFF_SECONDS * 2 ** (attempt - 1))

    print(f"Failed to get a response after {MAX_RETRIES} retries for respondent {fake_id}.")
    return False, None


for question_idx in range(START_INDEX, END_INDEX + 1):
    row = anes_2020_questionnaire.iloc[question_idx]
    code = row["Code"]
    question = row["Question"]
    answers = row["Answers"]
    # The user prompt depends only on the question, so build it once.
    user_prompt = generate_prompt_for_question(question, answers)

    for cond_idx, (column, value_sets) in enumerate(conditions.items()):
        if cond_idx < CONDITION_START_INDEX or cond_idx > CONDITION_END_INDEX:
            continue

        for value_set, label in value_sets.items():
            value_set_description = _describe_value_set(value_set)
            print(f"Running experiment for {column} with condition: {label} for values {value_set_description}")

            filtered_df = filter_dataframe(anesdf, value_set, column)
            distributions = compute_demographic_distribution(filtered_df)
            full_results = []

            # One synthetic respondent per real row of the sub-group. The
            # counter has its own name so it no longer shadows the question
            # index of the outer loop.
            for respondent_idx in tqdm(range(len(filtered_df))):
                fake_person = generate_fake_respondent(distributions)
                backstory = gen_backstory_from_fake_person(fake_person)
                system_prompt = time_date + backstory
                full_prompt = generate_query_with_backstory(system_prompt, user_prompt)
                fake_id = f"fake_{respondent_idx}"

                success, response = _query_with_retries(system_prompt, user_prompt, fake_id)

                print(full_prompt)
                if not success:
                    # No reply for this respondent: record nothing and do not
                    # print a response (previously this raised NameError or
                    # showed the previous respondent's reply).
                    continue

                full_results.append((fake_id, *fake_person.values(), full_prompt, response))
                print(response)

            # Save the results of this (question, sub-group) pair.
            columns = ["ID", *fields_of_interest.keys(), "Prompt", "Response"]
            df_results = pd.DataFrame(full_results, columns=columns)
            output_csv = f"output_{code}_{label}.csv"
            df_results.to_csv(output_csv, index=False)
